{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ee3d68e4",
   "metadata": {},
   "source": [
    "# Compare MUGEN's Video VQVAE with TorchMultimodal's\n",
    "\n",
    "This notebook loads the public MUGEN checkpoint for Video VQVAE, remaps the state_dict, and loads it into TorchMultimodal's Video VQVAE to ensure the outputs match. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5af9d001",
   "metadata": {},
   "source": [
    "### Set directories\n",
    "\n",
    "Replace these with your local directories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "071c8b48",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Root under which the MUGEN repo and the checkpoints live. Defaults to the\n",
    "# current user's home directory; edit this one line to relocate everything.\n",
    "# Joining with '' keeps a trailing path separator on each directory.\n",
    "home_dir = os.path.join(os.path.expanduser('~'), '')\n",
    "# Download/unzip target for the released MUGEN checkpoints.\n",
    "checkpoint_dir = os.path.join(home_dir, 'checkpoints', '')\n",
    "# Clone target. It must be a directory named `mugen` directly under\n",
    "# `home_dir`: a later cell appends `home_dir` to sys.path and then runs\n",
    "# `import mugen`.\n",
    "repo_dir = os.path.join(home_dir, 'mugen', '')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3a0f19f",
   "metadata": {},
   "source": [
    "### Clone MUGEN's repo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83812502",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clone the MUGEN baseline repository into `repo_dir`.\n",
    "!git clone https://github.com/mugen-org/MUGEN_baseline.git $repo_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07757cfa",
   "metadata": {},
   "source": [
    "### Download and unzip checkpoints\n",
    "\n",
    "This will take some time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d41a0c86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the released checkpoint archive into `checkpoint_dir`; the next\n",
    "# cell expects to find it there as `checkpoints.zip`.\n",
    "!wget https://dl.noahmt.com/creativity/data/MUGEN_release/checkpoints.zip -P $checkpoint_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01d9638a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Path of the downloaded archive; extract it into `checkpoint_dir` so the\n",
    "# checkpoint files end up under that same directory.\n",
    "zip_location = os.path.join(checkpoint_dir, 'checkpoints.zip')\n",
    "!unzip $zip_location -d $checkpoint_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f06c8938",
   "metadata": {},
   "source": [
    "### Load checkpoint into MUGEN model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f3e74b3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "# Put `home_dir` on the import path so that `import mugen` below can resolve\n",
    "# (assumes the repo was cloned to `<home_dir>/mugen` -- confirm if you changed\n",
    "# `repo_dir`).\n",
    "sys.path.append(home_dir)\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "import mugen\n",
    "\n",
    "# NOTE(review): torch.load unpickles the file, which can execute arbitrary\n",
    "# code -- only load checkpoints obtained from a source you trust.\n",
    "# map_location places every loaded tensor on the CPU.\n",
    "ckpt = torch.load(\n",
    "    os.path.join(checkpoint_dir, 'generation/video_vqvae/L32/epoch=54-step=599999.ckpt'), \n",
    "    map_location=torch.device('cpu')\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ea6d13e",
   "metadata": {},
   "source": [
    "The arguments are taken from MUGEN's training scripts found at: https://github.com/mugen-org/MUGEN_baseline/blob/main/generation/experiments/vqvae/VideoVQVAE_L32.sh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f81bea2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Namespace:\n",
    "    \"\"\"Minimal attribute bag: each keyword argument becomes an attribute.\"\"\"\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        # Write straight into the instance dictionary, one entry per keyword.\n",
    "        attributes = vars(self)\n",
    "        attributes.update(kwargs)\n",
    "\n",
    "\n",
    "# Hyperparameters bundled into a single attribute-style object, which is\n",
    "# what mugen.VQVAE is constructed from below.\n",
    "vqvae_args=Namespace(\n",
    "    embedding_dim=256,\n",
    "    n_codes=2048,\n",
    "    n_hiddens=240,\n",
    "    n_res_layers=4,\n",
    "    lr=0.0003,\n",
    "    # presumably per-axis (t, h, w) downsampling factors -- TODO confirm\n",
    "    downsample=(4, 32, 32),\n",
    "    kernel_size=3,\n",
    "    sequence_length=16,\n",
    "    resolution=256,\n",
    ")\n",
    "# Build the reference MUGEN model; its weights are loaded in the next cell.\n",
    "vv_mugen = mugen.VQVAE(vqvae_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fbdcf1f6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the checkpoint's weights (stored under its 'state_dict' entry) into\n",
    "# the MUGEN model; the cell output reports whether all keys matched.\n",
    "vv_mugen.load_state_dict(ckpt['state_dict'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6bfb325",
   "metadata": {},
   "source": [
    "### Create TorchMultimodal's Video VQVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "74e6bd54",
   "metadata": {},
   "outputs": [],
   "source": [
    "from examples.mugen.generation.video_vqvae import video_vqvae_mugen\n",
    "\n",
    "# pretrained_model_key=None: presumably builds the model without loading any\n",
    "# pretrained weights (TODO confirm); the remapped MUGEN weights are loaded\n",
    "# into it further down in this notebook.\n",
    "vv_torchmm = video_vqvae_mugen(pretrained_model_key=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e612d831",
   "metadata": {},
   "source": [
    "### Remap MUGEN's state_dict and load into new model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5f4d4774",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def map_state_dict(state_dict):\n",
    "    \"\"\"Rename MUGEN Video VQVAE checkpoint keys to TorchMultimodal's names.\n",
    "\n",
    "    Rules are tried in this order and the first one that applies wins:\n",
    "\n",
    "    * ``encoder.convs.<i>``   -> ``encoder.convs.<2*i>``\n",
    "    * ``encoder.conv_last``   -> ``encoder.convs.10``\n",
    "    * ``attn_<w|h|t>.<proj>`` -> ``mha_attns.<2|1|0>.<query|key|value|output>``\n",
    "    * ``pre_vq_conv``         -> ``encoder.conv_out``\n",
    "    * ``post_vq_conv``        -> ``decoder.conv_in``\n",
    "    * ``decoder.convts.<i>``  -> ``decoder.convts.<2*i>``\n",
    "    * ``codebook.N`` / ``codebook.z_avg`` / ``codebook.embeddings``\n",
    "      -> ``codebook.code_usage`` / ``codebook.code_avg`` / ``codebook.embedding``\n",
    "\n",
    "    Any other key is copied unchanged.\n",
    "\n",
    "    Args:\n",
    "        state_dict: Mapping from MUGEN parameter name to its value. Values are\n",
    "            passed through untouched (never copied or inspected).\n",
    "\n",
    "    Returns:\n",
    "        A new dict holding the same values under the remapped names.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If an ``attn_*`` key uses a projection name other than\n",
    "            ``w_qs``, ``w_ks``, ``w_vs`` or ``fc``.\n",
    "    \"\"\"\n",
    "    dim_map = {'w': '2', 'h': '1', 't': '0'}\n",
    "    layer_map = {'w_qs': 'query', 'w_ks': 'key', 'w_vs': 'value', 'fc': 'output'}\n",
    "    codebook_map = {\n",
    "        'codebook.N': 'codebook.code_usage',\n",
    "        'codebook.z_avg': 'codebook.code_avg',\n",
    "        'codebook.embeddings': 'codebook.embedding',\n",
    "    }\n",
    "\n",
    "    def double_index(match):\n",
    "        # Layer i in MUGEN's list sits at position 2*i in the target list.\n",
    "        return match.group(1) + str(int(match.group(2)) * 2)\n",
    "\n",
    "    def rename(param):\n",
    "        # Capture the whole index (\\d+) rather than a single character so\n",
    "        # that indices >= 10 are doubled correctly.\n",
    "        renamed, hits = re.subn(r'(encoder\\.convs\\.)(\\d+)', double_index, param, count=1)\n",
    "        if hits:\n",
    "            return renamed\n",
    "\n",
    "        res = re.search(r'encoder\\.conv_last', param)\n",
    "        if res:\n",
    "            # NOTE(review): index 10 is hard-coded; given the doubling above it\n",
    "            # presumably assumes exactly five entries in encoder.convs --\n",
    "            # confirm before reusing this with another configuration.\n",
    "            return param[:res.start()] + 'encoder.convs.10' + param[res.end():]\n",
    "\n",
    "        # Raw string, no stray comma in the character class, and a dot-free\n",
    "        # projection name so the match cannot run past the next separator.\n",
    "        res = re.search(r'attn_([wht])\\.([^.]+)\\.', param)\n",
    "        if res:\n",
    "            dim, layer = res.groups()\n",
    "            return (\n",
    "                param[:res.start()]\n",
    "                + 'mha_attns.' + dim_map[dim] + '.' + layer_map[layer] + '.'\n",
    "                + param[res.end():]\n",
    "            )\n",
    "\n",
    "        # For these two, everything before the marker is dropped and the\n",
    "        # remainder of the key is kept.\n",
    "        _, found, tail = param.partition('pre_vq_conv')\n",
    "        if found:\n",
    "            return 'encoder.conv_out' + tail\n",
    "        _, found, tail = param.partition('post_vq_conv')\n",
    "        if found:\n",
    "            return 'decoder.conv_in' + tail\n",
    "\n",
    "        renamed, hits = re.subn(r'(decoder\\.convts\\.)(\\d+)', double_index, param, count=1)\n",
    "        if hits:\n",
    "            return renamed\n",
    "\n",
    "        # Exact-name codebook buffers; anything else keeps its name.\n",
    "        return codebook_map.get(param, param)\n",
    "\n",
    "    return {rename(param): val for param, val in state_dict.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "38234858",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remap the MUGEN checkpoint's parameter names; the values are reused as-is.\n",
    "new_state_dict = map_state_dict(ckpt['state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e160fb51",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the remapped weights into the TorchMultimodal model; the cell output\n",
    "# reports whether all keys matched.\n",
    "vv_torchmm.load_state_dict(new_state_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46d58eb7",
   "metadata": {},
   "source": [
    "### Compare outputs with a random input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3c85cdd3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Max difference between outputs: 3.0875205993652344e-05\n",
      "Mean difference between outputs: 1.7353995929170196e-07\n"
     ]
    }
   ],
   "source": [
    "# Fixed seed so the random input video is the same on every run.\n",
    "torch.manual_seed(4)\n",
    "video = torch.randn(1,3,32,256,256) # b, c, t, h, w\n",
    "\n",
    "# Put both models in eval mode before comparing their outputs.\n",
    "vv_mugen.eval()\n",
    "vv_torchmm.eval()\n",
    "\n",
    "# Only the forward outputs are compared, so skip autograd bookkeeping; this\n",
    "# avoids keeping the graphs of two large video models in memory.\n",
    "with torch.no_grad():\n",
    "    loss, x_recon, codebook_output = vv_mugen(video)\n",
    "    output = vv_torchmm(video)\n",
    "\n",
    "# Element-wise absolute difference between the two reconstructions.\n",
    "diff = torch.abs(output.decoded - x_recon)\n",
    "print(f'Max difference between outputs: {torch.max(diff).item()}')\n",
    "print(f'Mean difference between outputs: {torch.mean(diff).item()}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa78569e",
   "metadata": {},
   "source": [
    "### Save mapped checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "48651d44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save under the configured `checkpoint_dir` (next to the source checkpoint's\n",
    "# folder) instead of a machine-specific absolute path.\n",
    "save_path = os.path.join(checkpoint_dir, 'generation/video_vqvae/mugen_video_vqvae_L32.pt')\n",
    "torch.save(new_state_dict, save_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
